from header import *    # from .header import * will raise error
from .coherence import *
from .logic import *
from .fluency import *
from .topic import *
from .nli import *
from .diversity import *
from .mmi import *
from .bert_multiview import *
from .bertmc import *
from .lccc_lm import *

class MultiView(nn.Module):

    '''
    Multi-view metric for open-domain dialog systems.

    Each enabled "view" is a sub-model that scores a batch of responses;
    `forward` returns the per-view scores and their weighted sum.
    Views that can be switched on:
        bertmultiview, coherence, logic, topic, fluency, nli, distinct,
        length, nidf_tf, mmi, repetition_penalty, bertmcf, lccc

    The multi-view metric is used in this repo to:
    1. rerank the responses generated by GPT2
    2. act as a better evaluator for the RL-based GPT2 fine-tuning
    3. rerank the final responses (GPT2, retrieval, MRC, KBQA)

    The metric only predicts; it is never trained (forward runs under
    `torch.no_grad()`).
    '''

    def __init__(self, nli=False, coherence=False, length=False,
                 logic=False, topic=False, fluency=False, nidf_tf=False,
                 repetition_penalty=False, distinct=False, bertmcf=False,
                 mmi=False, lccc=None, coherence_path=None, nli_path=None, 
                 logic_path=None, topic_path=None, mmi_path=None,
                 bertmcf_path=None, fluency_path=None, 
                 bertmultiview=None, bertmultiview_path=None, lccc_path=None):
        '''
        Enable the requested views and load their sub-models.

        Every boolean flag switches one view on; the matching `*_path`
        argument is the checkpoint handed to that view's loader.

        :lccc: None (default) keeps the historical behaviour of enabling the
               LCCC view whenever `lccc_path` is given; an explicit True/False
               decides on its own (previously the flag was ignored and the
               path alone enabled the view, even with `lccc=False`).
        :raises Exception: when topic / coherence / fluency / logic /
               bertmultiview / bertmcf / lccc is enabled without its path.
        '''
        super(MultiView, self).__init__()
        # the flag wins when it is given explicitly; otherwise fall back to
        # "a path was supplied" so existing path-only callers keep working
        lccc_enabled = bool(lccc_path) if lccc is None else bool(lccc)
        # insertion order of this dict fixes both the loading order and the
        # key order of the score dict returned by forward
        self.mode = {
                'bertmultiview': bertmultiview,
                'coherence': coherence,
                'logic': logic,
                'topic': topic,
                'fluency': fluency,
                'nli': nli,
                'distinct': distinct,
                'length': length,
                'nidf_tf': nidf_tf,
                'mmi': mmi,
                'repetition_penalty': repetition_penalty,
                'bertmcf': bertmcf,
                'lccc': lccc_enabled,
        }
        # weight of every view in the final weighted sum.
        # 'logic' and 'nli' had no entry originally, which made forward raise
        # KeyError as soon as one of them was enabled; they get the neutral
        # weight 1 here -- NOTE(review): tune if another weight is intended.
        self.mode_weight = {
                'bertmultiview': 1,
                'coherence': 1,
                'logic': 1,
                'topic': 1.2,
                'fluency': 0.5,
                'nli': 1,
                'length': 0.4,
                'nidf_tf': 0.6,
                'mmi': 0.5,
                'distinct': 0.6,
                'repetition_penalty': 0.2,
                'bertmcf': 1,
                'lccc': 1,
        }
        # used by topic_scores to translate a Chinese topic name into the
        # label it is compared against
        self.topic_map = {'电影': 'movie', '美食': 'food', '数码产品': 'electric', '音乐': 'music', '体育': 'sport'}
        # checkpoint path of every view whose loader takes one
        paths = {
                'topic': topic_path,
                'coherence': coherence_path,
                'fluency': fluency_path,
                'logic': logic_path,
                'nli': nli_path,
                'mmi': mmi_path,
                'bertmultiview': bertmultiview_path,
                'bertmcf': bertmcf_path,
                'lccc': lccc_path,
        }
        # check whether the essential paths exist (nli and mmi are not
        # checked here, exactly as before)
        essential = ('topic', 'coherence', 'fluency', 'logic',
                     'bertmultiview', 'bertmcf', 'lccc')
        if any(self.mode[name] and not paths[name] for name in essential):
            raise Exception(f'[!] essential path is not found')
        # load sub-models
        self.model = {}
        for name, enabled in self.mode.items():
            if enabled:
                self.model[name] = self._load_submodel(name, paths.get(name))
        print(f'[!] init the multview module over, available models are shown as follows:')
        # show the available models
        for k, v in self.mode.items():
            if v: 
                print(f'{k}')

    def _load_submodel(self, name, path):
        '''
        Build the sub-model of the view `name` and load `path` into it when
        the view uses a checkpoint. Returns the ready sub-model.

        An if/elif chain is kept on purpose: every class name is resolved only
        when its view is actually enabled.
        '''
        if name == 'topic':
            return ff.load_model(path)
        if name == 'length':
            return Length()
        if name == 'nidf_tf':
            return NIDF_TF()
        if name == 'repetition_penalty':
            return RepetitionPenalty()
        if name == 'distinct':
            return Distinct()
        if name == 'lccc':
            return LCCCLM(path, 0, 0.9)
        if name == 'coherence':
            submodel = COHERENCE()
        elif name == 'logic':
            submodel = LOGIC()
        elif name == 'nli':
            submodel = NLI()
        elif name == 'mmi':
            submodel = MMI()
        elif name == 'fluency':
            # original note: safety need two models (gpt2, mmi gpt2)
            submodel = SAFETY_FLUENCY()
        elif name == 'bertmultiview':
            submodel = BERT_MULTIVIEW()
        elif name == 'bertmcf':
            submodel = BERTMCF()
        else:
            raise ValueError(f'unknown view: {name}')
        submodel.load_model(path)
        return submodel

    def topic_scores(self, msg, topic):
        '''
        Decide whether a single message is acceptable for a topic.

        :msg: a string; it is tokenized by jieba before prediction
        :topic: a key of `self.topic_map` (KeyError otherwise)

        Returns True when the predicted label equals the mapped topic, or
        when the prediction value is at most 0.4; returns False otherwise,
        and also when the prediction itself fails (best effort).
        '''
        topic = self.topic_map[topic]
        msg = ' '.join(jieba.cut(msg))
        try:
            label, value = self.model['topic'].predict(msg)
        except Exception:
            # best effort: a failing prediction counts as "off topic", but
            # KeyboardInterrupt / SystemExit are no longer swallowed
            return False
        label = label[0].replace('__label__', '')
        value = value[0]
        return bool(topic == label or value <= 0.4)

    @torch.no_grad()
    def forward(self, context, response, groundtruth=None, topic=None, history=None, bertmultiview_details=False):
        '''
        Score one batch with every enabled view.

        context: list of conversation contexts, its length is the batch size
        response: list of responses
        groundtruth: list of references, required by the bertmcf view
        topic: list of topic labels, required by the topic view; compared
               directly with the predicted label after `__label__` is
               removed (no `topic_map` translation here)
        history: passed to the distinct view together with the responses
        bertmultiview_details: forwarded to the bertmultiview view as
               `details`; when true no weighted sum is computed

        return (average_scores, scores)
        :average_scores: [batch] weighted sum over the views, or None when
                         `bertmultiview_details` is true
        :scores: dict view name -> the scores returned by that view
        '''
        scores = {k: [] for k, v in self.mode.items() if v}
        for k in scores:
            # fasttext short text classification model predict
            # besides, the string should be tokenized by jieba
            if k == 'topic':
                assert topic is not None, 'topic must use the topic labels'
                response_ = [' '.join(jieba.cut(i)) for i in response]
                label, value = self.model[k].predict(response_)
                label = [i[0].replace('__label__', '') for i in label]
                value = [i[0] for i in value]
                # keep the value when the label matches the expected topic,
                # otherwise use its complement
                scores[k] = [v if l == t else 1 - v for l, t, v in zip(label, topic, value)]
            elif k in ['length', 'nidf_tf']:
                scores[k] = self.model[k].scores(response)
            elif k in ['distinct']:
                scores[k] = self.model[k].scores(response, history)
            elif k in ['bertmultiview']:
                scores[k] = self.model[k].scores(context, response, details=bertmultiview_details)    # [list]
            elif k in ['bertmcf']:
                assert groundtruth is not None, 'bertmcf must use the groundtruth'
                scores[k] = self.model[k].scores(context, groundtruth, response)
            elif k in ['lccc']:
                scores[k] = self.model[k].scores(context, response, temperature=0.7)
            else:
                scores[k] = self.model[k].scores(context, response)    # [list]
        if bertmultiview_details:
            return None, scores
        batch_size = len(context)
        # weighted sum over the enabled views for every sample of the batch
        average_scores = [
            np.sum([v[idx] * self.mode_weight[key] for key, v in scores.items()])
            for idx in range(batch_size)
        ]    # [batch]
        return average_scores, scores
        
        
# ========= test the evaluation of the generated responses =========
def _strip_special_tokens(text, tokens=('[CLS]', '[SEP]')):
    '''
    Remove whole special tokens (and the whitespace around them) from both
    ends of `text`; tokens inside the text are left alone.
    '''
    text = text.strip()
    changed = True
    while changed:
        changed = False
        for token in tokens:
            if text.startswith(token):
                text = text[len(token):].strip()
                changed = True
            if text.endswith(token):
                text = text[:-len(token)].strip()
                changed = True
    return text


def read_evaluation_data(path):
    '''
    Read the evaluation file at `path` into a list of
    (context, target, candidate) string tuples.

    Records are separated by a blank line and hold exactly three lines
    (context, target, candidate); any other line count raises ValueError on
    unpacking. The first 5 characters of every line are dropped --
    presumably a fixed-width field tag, TODO confirm against the writer of
    the file. The utterances of the context are re-joined with ' [SEP] '.

    `str.strip('[CLS]')` removes any of the *characters* `[`, `C`, `L`, `S`,
    `]` and therefore also ate real text at the ends of an utterance; whole
    tokens are removed instead. The file is read as UTF-8 rather than with
    the locale encoding.
    '''
    dataset = []
    with open(path, encoding='utf-8') as f:
        data = f.read().split('\n\n')
    for dialog in tqdm(data):
        dialog = dialog.strip()
        if not dialog:
            continue
        context, target, candidate = dialog.split('\n')
        utterances = _strip_special_tokens(context[5:]).split('[SEP]')
        context = ' [SEP] '.join(utterances).strip()

        target = _strip_special_tokens(target[5:])
        candidate = _strip_special_tokens(candidate[5:])
        dataset.append((context, target, candidate))
    return dataset

def collect_results(path, model, dataset, batch_size=32):
    '''
    only for bertmcf model in multiview

    Score `dataset` batch by batch with `model` and write one text record
    per sample (context, groundtruth, candidate, score) to `path`.

    :path: output file, overwritten; written as UTF-8 so that non-ASCII
           dialog text does not depend on the locale encoding
    :model: called as model(contexts, candidates, groundtruth=...); the
            second item of its result must hold the 'bertmcf' scores
    :dataset: sequence of (context, groundtruth, candidate) tuples
    :batch_size: number of samples scored per model call
    '''
    with open(path, 'w', encoding='utf-8') as f:
        for start in tqdm(range(0, len(dataset), batch_size)):
            batch = dataset[start:start + batch_size]
            contexts = [sample[0] for sample in batch]
            groundtruths = [sample[1] for sample in batch]
            candidates = [sample[2] for sample in batch]
            rest = model(contexts, candidates, groundtruth=groundtruths)
            # rest is (average_scores, scores); only the bertmcf view is kept
            rest = rest[1]['bertmcf']
            for c, g, ca, s in zip(contexts, groundtruths, candidates, rest):
                f.write(f'[Context]: {c}\n[Groundtruth]: {g}\n[Candidate]: {ca}\n[Score]: {s}\n\n')
    print(f'[!] write the results into {path}')

if __name__ == "__main__":
    # CUDA_VISIBLE_DEVICES=0 python -m multiview.multiviews
    # which views are switched on for this smoke test
    view_switches = dict(
        topic=False,
        coherence=True,
        length=False,
        nidf_tf=False,
        fluency=False,
        lccc=False,
        repetition_penalty=False,
        mmi=False,
        distinct=False,
        bertmultiview=False,
        bertmcf=False,
    )
    # checkpoints handed to the constructor
    checkpoint_paths = dict(
        bertmcf_path='ckpt/zh50w/bertmcf/best.pt',
        bertmultiview_path='ckpt/zh50w/bertretrieval_multiview/best.pt',
        mmi_path='ckpt/train_generative/gpt2_mmi/best.pt',
        coherence_path='ckpt/zh50w/bertretrieval/best.pt',
        topic_path='ckpt/fasttext/model.bin',
        fluency_path='ckpt/LM/gpt2lm/best.pt',
        lccc_path='/data/lantian/data/LCCD_GPT',
    )
    model = MultiView(**view_switches, **checkpoint_paths)

    # file-based evaluation, kept disabled:
    # dataset = read_evaluation_data('rest/train_generative/gpt2/rest.txt')
    # collect_results('multiview/evaluation_rest.txt', model, dataset)

    # candidate responses scored against one fixed context; the last one is
    # a single character repeated 500 times
    responses = [
        '电影比较喜欢我这种类型的泰坦尼克号',
        '我比较喜欢泰坦尼克号这种电影类型的',
        '我还是挺喜欢恐怖电影的',
        '恐怖电影非常刺激',
        '我喜欢打乒乓球',
        '我打乒乓球',
        '我喜欢和朋友爬山',
        '我要去和朋友爬山',
        '你喜欢什么电影',
        '你最喜欢什么电影类型呢',
        '今天天气确实挺好的',
        '我不想知道你喜不喜欢电影',
        '中国足球必胜',
        '我了解电影',
        '我不喜欢我不喜欢电影',
        '我' * 500,
    ]
    n_samples = len(responses)
    groundtruths = ['我比较喜欢看科幻片，可以激发我无尽的想象'] * n_samples
    # test the performance of the multiview metric
    contexts = ['你喜欢什么类型的电影呢'] * n_samples
    topic = ['movie'] * n_samples
    history = [
        '来分享你最近看过的电影吧',
        '我最近看了一部恐怖片',
        '你难道喜欢看恐怖片么',
    ]

    rest = model(
        contexts,
        responses,
        groundtruth=groundtruths,
        topic=topic,
        history=history,
        bertmultiview_details=False,
    )
    pprint.pprint(rest)
